{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b4ee17ad-e38b-4d49-a8d6-0ec910d9e528",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch,sys,os,time,skimage\n",
    "import numpy as np\n",
    "import mcubes,trimesh\n",
    "from tqdm import tqdm\n",
    "from skimage import measure\n",
    "import matplotlib.pyplot as plt\n",
    "from omegaconf import OmegaConf\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "sys.path.append('..')\n",
    "\n",
    "from models.FactorFields import FactorFields \n",
    "\n",
    "from utils import SimpleSampler\n",
    "from dataLoader import dataset_dict\n",
    "\n",
    "device = 'cuda'\n",
    "torch.cuda.set_device(0)\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c8a83bea-cfc1-40cd-bbcc-8a79a13949e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "@torch.no_grad()\n",
    "def eval_sdf(reso, bbox, chunk=10240):\n",
    "    z = torch.linspace(0, bbox,reso[2])\n",
    "    y = torch.linspace(0, bbox,reso[1])\n",
    "    x = torch.linspace(0, bbox,reso[0])\n",
    "    \n",
    "    coordiantes = torch.empty((reso[2],reso[1],reso[0],3))\n",
    "    coordiantes[...,0], coordiantes[...,1], coordiantes[...,2] = torch.meshgrid((x, y, z), indexing='ij')\n",
    "    res = torch.empty(reso[2]*reso[1]*reso[0])\n",
    "    \n",
    "    count = 0\n",
    "    coordiantes = coordiantes.reshape(-1,3)#/(torch.FloatTensor(reso[::-1])-1)*2-1\n",
    "    coordiantes = torch.split(coordiantes,chunk,dim=0)\n",
    "    for coordiante in tqdm(coordiantes):\n",
    "\n",
    "        feats,_ = model.get_coding(coordiante.to(model.device))\n",
    "        y_recon = model.linear_mat(feats)\n",
    "        \n",
    "        res[count:count+y_recon.shape[0]] = y_recon.cpu().view(-1)\n",
    "        count += y_recon.shape[0]\n",
    "        # res.append(y_recon.cpu())\n",
    "    return res.reshape(*reso)\n",
    "\n",
    "def eval_point(points):\n",
    "    feats,_ = model.get_basis(points.to(device))\n",
    "    return model.linear_mat(feats).squeeze().cpu()\n",
    "\n",
    "def marchcude_to_world(vertices, reso_WHD):\n",
    "    return vertices/(np.array(reso_WHD)-1)*2-1\n",
    "\n",
    "@torch.no_grad()\n",
    "def cal_l1_iou(test_dataset, chunk=10240):\n",
    "    sdf, coordiantes = test_dataset.sdf, test_dataset.coordiante\n",
    "    \n",
    "    sdf_pred = []\n",
    "    for coordiante in torch.split(coordiantes, chunk, dim=0):\n",
    "        feats,_ = model.get_coding(coordiante.to(model.device))\n",
    "        y_recon = model.linear_mat(feats)\n",
    "        \n",
    "        sdf_pred.append(y_recon.cpu())\n",
    "    sdf_pred = torch.cat(sdf_pred)\n",
    "    \n",
    "    l1 = (sdf_pred-sdf).abs().mean()\n",
    "    iou =  torch.sum((sdf>0)&(sdf_pred>0)) / torch.sum(((sdf>0)|(sdf_pred>0)))\n",
    "    return l1, iou\n",
    "\n",
    "avg_pool_3d = torch.nn.AvgPool3d(2, stride=2)\n",
    "upsample = torch.nn.Upsample(scale_factor=2, mode='nearest')\n",
    "@torch.no_grad()\n",
    "def get_surface_sliding(sdf, path=None, resolution=512, grid_boundary=[-1.0, 1.0], return_mesh=False, level=0):\n",
    "    assert resolution % 512 == 0\n",
    "    resN = resolution\n",
    "    cropN = 512\n",
    "    level = 0\n",
    "    N = resN // cropN\n",
    "\n",
    "    grid_min = [grid_boundary[0], grid_boundary[0], grid_boundary[0]]\n",
    "    grid_max = [grid_boundary[1], grid_boundary[1], grid_boundary[1]]\n",
    "    xs = np.linspace(grid_min[0], grid_max[0], N+1)\n",
    "    ys = np.linspace(grid_min[1], grid_max[1], N+1)\n",
    "    zs = np.linspace(grid_min[2], grid_max[2], N+1)\n",
    "\n",
    "\n",
    "    meshes = []\n",
    "    for i in range(N):\n",
    "        for j in range(N):\n",
    "            for k in range(N):\n",
    "                x_min, x_max = xs[i], xs[i+1]\n",
    "                y_min, y_max = ys[j], ys[j+1]\n",
    "                z_min, z_max = zs[k], zs[k+1]\n",
    "\n",
    "                x = np.linspace(x_min, x_max, cropN)\n",
    "                y = np.linspace(y_min, y_max, cropN)\n",
    "                z = np.linspace(z_min, z_max, cropN)\n",
    "\n",
    "                xx, yy, zz = np.meshgrid(x, y, z, indexing='ij')\n",
    "                points = torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float)\n",
    "                \n",
    "                def evaluate(points):\n",
    "                    z = []\n",
    "                    for _, pnts in enumerate(torch.split(points, 100000, dim=0)):\n",
    "                        z.append(-sdf(pnts))\n",
    "                    z = torch.cat(z, axis=0)\n",
    "                    return z\n",
    "            \n",
    "                # construct point pyramids\n",
    "                points = points.reshape(cropN, cropN, cropN, 3).permute(3, 0, 1, 2)\n",
    "                \n",
    "                points_pyramid = [points]\n",
    "                for _ in range(3):            \n",
    "                    points = avg_pool_3d(points[None])[0]\n",
    "                    points_pyramid.append(points)\n",
    "                points_pyramid = points_pyramid[::-1]\n",
    "                \n",
    "                # evalute pyramid with mask\n",
    "                mask = None\n",
    "                threshold = 2 * (x_max - x_min)/cropN * 8\n",
    "                for pid, pts in enumerate(points_pyramid):\n",
    "                    coarse_N = pts.shape[-1]\n",
    "                    pts = pts.reshape(3, -1).permute(1, 0).contiguous()\n",
    "                    \n",
    "                    if mask is None:    \n",
    "                        pts_sdf = evaluate(pts)\n",
    "                    else:                    \n",
    "                        mask = mask.reshape(-1)\n",
    "                        pts_to_eval = pts[mask]\n",
    "                        #import pdb; pdb.set_trace()\n",
    "                        if pts_to_eval.shape[0] > 0:\n",
    "                            pts_sdf_eval = evaluate(pts_to_eval.contiguous())\n",
    "                            pts_sdf[mask] = pts_sdf_eval\n",
    "\n",
    "                    if pid < 3:\n",
    "                        # update mask\n",
    "                        mask = torch.abs(pts_sdf) < threshold\n",
    "                        mask = mask.reshape(coarse_N, coarse_N, coarse_N)[None, None]\n",
    "                        mask = upsample(mask.float()).bool()\n",
    "\n",
    "                        pts_sdf = pts_sdf.reshape(coarse_N, coarse_N, coarse_N)[None, None]\n",
    "                        pts_sdf = upsample(pts_sdf)\n",
    "                        pts_sdf = pts_sdf.reshape(-1)\n",
    "\n",
    "                    threshold /= 2.\n",
    "\n",
    "                z = pts_sdf.detach().cpu().numpy()\n",
    "\n",
    "                if (not (np.min(z) > level or np.max(z) < level)):\n",
    "                    z = z.astype(np.float32)\n",
    "                    verts, faces, normals, values = measure.marching_cubes(\n",
    "                    volume=z.reshape(cropN, cropN, cropN), #.transpose([1, 0, 2]),\n",
    "                    level=level,\n",
    "                    spacing=(\n",
    "                            (x_max - x_min)/(cropN-1),\n",
    "                            (y_max - y_min)/(cropN-1),\n",
    "                            (z_max - z_min)/(cropN-1) ))\n",
    "\n",
    "                    verts = verts + np.array([x_min, y_min, z_min])\n",
    "                    \n",
    "                    meshcrop = trimesh.Trimesh(verts, faces)\n",
    "                    #meshcrop.export(f\"{i}_{j}_{k}.ply\")\n",
    "                    meshes.append(meshcrop)\n",
    "\n",
    "    combined = trimesh.util.concatenate(meshes)\n",
    "\n",
    "    combined.vertices = combined.vertices/grid_boundary[1]*2-1.0\n",
    "    \n",
    "    if return_mesh:\n",
    "        return combined\n",
    "    elif path is not None:\n",
    "        combined.export(f'{path}.ply')  \n",
    "        \n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "1ad08e70-1d51-4cf1-b549-26b012acf0a3",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=====> total parameters:  5342274\n",
      "/vlg-nfs/anpei/Code/NeuBasis/data/mesh//armadillo_8M.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Iteration 09900: loss_dist = 0.00000025: 100%|███████████████████████████████████████████████████| 10000/10000 [00:30<00:00, 325.66it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.986945) 30.70767855644226\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████| 104858/104858 [01:39<00:00, 1055.10it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=====> total parameters:  5342274\n",
      "/vlg-nfs/anpei/Code/NeuBasis/data/mesh//statuette_8M.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Iteration 09900: loss_dist = 0.00001105: 100%|███████████████████████████████████████████████████| 10000/10000 [00:31<00:00, 316.70it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.966919) 31.576430797576904\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████| 104858/104858 [01:37<00:00, 1075.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=====> total parameters:  5342274\n",
      "/vlg-nfs/anpei/Code/NeuBasis/data/mesh//dragon_8M.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Iteration 09900: loss_dist = 0.00000032: 100%|███████████████████████████████████████████████████| 10000/10000 [00:30<00:00, 329.56it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.980855) 30.343759059906006\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████| 104858/104858 [01:37<00:00, 1070.92it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=====> total parameters:  5342274\n",
      "/vlg-nfs/anpei/Code/NeuBasis/data/mesh//lucy_8M.npy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Iteration 09900: loss_dist = 0.00000029: 100%|███████████████████████████████████████████████████| 10000/10000 [00:30<00:00, 333.20it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.983279) 30.012317180633545\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████| 104858/104858 [01:37<00:00, 1073.82it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9794994\n"
     ]
    }
   ],
   "source": [
    "\n",
    "base_conf = OmegaConf.load('../configs/defaults.yaml')\n",
    "second_conf = OmegaConf.load('../configs/sdf.yaml')\n",
    "cfg = OmegaConf.merge(base_conf, second_conf)\n",
    "\n",
    "dataset = dataset_dict[cfg.dataset.dataset_name]\n",
    "\n",
    "is_save_mesh = False\n",
    "\n",
    "scores = []\n",
    "for mode in ['8M']:\n",
    "    for scene in ['armadillo','statuette','dragon','lucy']:\n",
    "\n",
    "        cfg.dataset.datadir = f'../data/SDF/{scene}_{mode}.npy'\n",
    "        train_dataset = dataset(cfg.dataset, split='train')\n",
    "        test_dataset = dataset(cfg.dataset, split='test')\n",
    "\n",
    "        batch_size = cfg.training.batch_size\n",
    "        n_iter = cfg.training.n_iters\n",
    "\n",
    "        model = FactorFields(cfg, device)\n",
    "\n",
    "        print(cfg.dataset.datadir)\n",
    "        sdf, coordiantes = train_dataset.sdf.to(device), train_dataset.coordiante.to(device)\n",
    "        trainingSampler = SimpleSampler(len(train_dataset), cfg.training.batch_size)\n",
    "\n",
    "        grad_vars = model.get_optparam_groups(lr_small=cfg.training.lr_small,lr_large=cfg.training.lr_large)\n",
    "        optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))#\n",
    "\n",
    "\n",
    "        loss_scale = 1.0\n",
    "        lr_factor = 0.1 ** (1 / n_iter)\n",
    "        pbar = tqdm(range(n_iter))\n",
    "        start = time.time()\n",
    "        for iteration in pbar:\n",
    "            loss_scale *= lr_factor\n",
    "\n",
    "\n",
    "            pixel_idx = torch.randint(0,len(train_dataset),(batch_size,))\n",
    "\n",
    "\n",
    "            feats, coeffs = model.get_coding(coordiantes[pixel_idx])\n",
    "            sdf_recon = model.linear_mat(feats)\n",
    "\n",
    "            loss_dist = torch.mean((sdf_recon-sdf[pixel_idx])**2) \n",
    "\n",
    "            loss = loss_dist * loss_scale\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            if iteration%100==0:\n",
    "                pbar.set_description(\n",
    "                            f'Iteration {iteration:05d}:'\n",
    "                            + f' loss_dist = {loss_dist.item():.8f}'\n",
    "                        )\n",
    "\n",
    "            \n",
    "        time_takes = time.time()-start\n",
    "        reso = train_dataset.DHW[::-1]\n",
    "        # sdf_res = eval_sdf([384]*3)\n",
    "        mae,  gIoU = cal_l1_iou(test_dataset)\n",
    "        torch.set_printoptions(precision=6)\n",
    "        scores.append(gIoU)\n",
    "        print(gIoU,time_takes)\n",
    "        np.savetxt(f'../logs/SDF/{scene}.txt',np.array([gIoU.item(),time_takes]))\n",
    "\n",
    "        if is_save_mesh:\n",
    "            _reso = 1024\n",
    "            sdf_res = eval_sdf([_reso]*3,train_dataset.DHW[0])\n",
    "            vertices, triangles, normals, values = skimage.measure.marching_cubes(\n",
    "                    sdf_res.numpy(), level=0.0\n",
    "                )\n",
    "            vertices = marchcude_to_world(vertices, [_reso]*3)\n",
    "            triangles = triangles[...,::-1]\n",
    "            mesh = trimesh.Trimesh(vertices, triangles)\n",
    "            mesh.export(f'../logs/SDF//{scene}.ply');\n",
    "        \n",
    "print(np.mean(scores))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac23a83d-ae78-4186-bcb7-b1d5ab73dd8a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "factorfield",
   "language": "python",
   "name": "factorfield"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
